import copy
import math
import random
import time
import scipy.cluster.vq as vq
import sys
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import decimal
import numpy as np
import scipy.spatial.distance as dist

# Upper bound of the random integers drawn when initialising the membership
# matrix U (see initialise_U). A `global` statement is not needed at module
# level, so the name is simply assigned.
MAX = 10000.0
# The stopping threshold used to be a module-level constant; it is kept here
# for reference only:
# Epsilon = 0.00000001


def import_data_format_iris(file):
    """
    Load an Iris-style comma-separated file into feature rows and labels.

    Every non-blank line is split on commas: all fields except the last are
    converted to float, and the last field is the species name.
    Dataset: http://archive.ics.uci.edu/ml/machine-learning-databases/iris/

    :param file: path of the file to read (passed through ``str``).
    :return: ``(data, cluster_location)`` where ``data`` is a list of float
        lists and ``cluster_location`` is a parallel list of labels:
        0 for "Iris-setosa", 1 for "Iris-versicolor", 2 for anything else.
    :raises ValueError: if a feature field cannot be converted to float.
    """
    species_to_label = {"Iris-setosa": 0, "Iris-versicolor": 1}
    data = []
    cluster_location = []
    with open(str(file), 'r') as f:
        for line in f:
            fields = line.strip().split(",")
            # A blank line carries no sample; skip it instead of indexing
            # past the end of the field list.
            if fields == [""]:
                continue
            *features, species = fields
            data.append([float(value) for value in features])
            # The line was already stripped, so compare against the bare
            # species name (no trailing newline).
            cluster_location.append(species_to_label.get(species.strip(), 2))
    print("加载数据完毕")
    return data, cluster_location


def randomise_data(data):
    """
    Shuffle the samples and remember the permutation that was applied.

    :param data: indexable collection of samples.
    :return: ``(new_data, order)`` where ``order`` is a shuffled list of the
        indices ``0 .. len(data) - 1`` and ``new_data[i] == data[order[i]]``.
    """
    order = list(range(len(data)))
    random.shuffle(order)
    shuffled = [data[source_index] for source_index in order]
    return shuffled, order


def de_randomise_data(data, order):
    """
    Undo a shuffle produced by randomise_data().

    :param data: the shuffled samples.
    :param order: the ``order`` list returned by randomise_data().
    :return: a list in which ``data[i]`` has been put back at position
        ``order[i]``; positions never written keep an empty list.
    """
    restored = [[] for _ in range(len(data))]
    for shuffled_pos, original_pos in enumerate(order):
        restored[original_pos] = data[shuffled_pos]
    return restored


def print_matrix(list):
    """
    Print a matrix one row per line.

    :param list: sequence of rows; each row is printed with ``print``.
    """
    # NOTE: the parameter name shadows the builtin but is kept because it is
    # part of the function's interface.
    for row in list:
        print(row)


def initialise_U(data, cluster_number):
    """
    Build a random membership matrix U whose rows each sum to 1.

    For every sample, ``cluster_number`` random integers in ``[1, int(MAX)]``
    are drawn and divided by their sum. Relies on the module-level ``MAX``.

    :param data: the samples; only ``len(data)`` is used.
    :param cluster_number: number of clusters (columns of U).
    :return: list of ``len(data)`` rows, each a list of ``cluster_number``
        floats.
    """
    membership = []
    for _ in range(len(data)):
        draws = [random.randint(1, int(MAX)) for _ in range(cluster_number)]
        # Float start value keeps the division identical to a float sum.
        total = sum(draws, 0.0)
        membership.append([value / total for value in draws])
    return membership

def end_conditon(U, U_old, Epsilon=1E-6):
    """
    Stopping test: report whether U has stopped changing between iterations.

    :param U: current membership matrix (rows indexable by column).
    :param U_old: membership matrix of the previous iteration.
    :param Epsilon: largest absolute per-entry change still treated as
        "unchanged".
    :return: False as soon as one entry differs by more than ``Epsilon``,
        otherwise True (also True for an empty U).
    """
    row_count = len(U)
    if row_count == 0:
        return True
    # The column count is taken from the first row, as for every row.
    col_count = len(U[0])
    changed = any(
        abs(U[r][c] - U_old[r][c]) > Epsilon
        for r in range(row_count)
        for c in range(col_count)
    )
    return not changed


def normalise_U(U):
    """
    Defuzzify a membership matrix into one hard cluster label per sample.

    The previous implementation rebound its result array to a scalar column
    index inside the loop, so it never returned a label per sample. This
    version returns, for each row, the column holding the largest membership.

    :param U: 2-D membership matrix (samples x clusters); a list of lists or
        a NumPy array.
    :return: 1-D integer NumPy array of length ``len(U)``; entry ``i`` is the
        index of the maximum of row ``i`` (the first one on ties).
    :raises ValueError: if ``U`` is not two-dimensional.
    """
    U = np.asarray(U)
    if U.ndim != 2:
        raise ValueError("U must be a 2-D membership matrix")
    return np.argmax(U, axis=1)


def _cluster_centers(data, U, cluster_number, m):
    """Return the membership-weighted mean of the samples for every cluster."""
    centers = []
    for j in range(cluster_number):
        # U[k][j] ** m does not depend on the feature, so compute it once.
        weights = [U[k][j] ** m for k in range(len(data))]
        weight_sum = sum(weights, 0.0)
        center = []
        for i in range(len(data[0])):
            weighted = sum((w * data[k][i] for k, w in enumerate(weights)), 0.0)
            center.append(weighted / weight_sum)
        centers.append(center)
    return centers


def _membership_row(distances, exponent):
    """Return the memberships of one sample given its distances to all centers."""
    coincident = [k for k, d in enumerate(distances) if d == 0]
    if coincident:
        # The sample sits exactly on one or more centers: the ratio formula
        # would divide by zero, so give those centers the full membership.
        share = 1.0 / len(coincident)
        return [share if k in coincident else 0.0 for k in range(len(distances))]
    return [
        1 / sum((d_j / d_k) ** exponent for d_k in distances)
        for d_j in distances
    ]


# The recommended range for m is [1.5, 2.5].
def fuzzy(data, cluster_number, m=2, maximum=100, minimum=1E-6,thresh='euclidean', option=None):
    """
    Run fuzzy C-means and return the membership matrix and cluster centers.

    :param data: samples, indexable as ``data[k][i]`` (sample k, feature i).
    :param cluster_number: number of clusters.
    :param m: fuzziness exponent; must not be 1 (the update uses
        ``2 / (m - 1)``).
    :param maximum: maximum number of iterations.
    :param minimum: convergence threshold passed to end_conditon().
    :param thresh: distance metric name handed to
        ``scipy.spatial.distance.cdist``.
    :param option: optional dict that may override ``m``, ``maximum``,
        ``minimum`` and ``thresh``; missing keys keep the argument values.
    :return: ``(U, C)`` - the membership matrix as a list of rows, and the
        cluster centers computed in the last iteration (from the memberships
        as they were before that iteration's update).
    """
    if option is not None:
        m = option.get('m', m)
        maximum = option.get('maximum', maximum)
        minimum = option.get('minimum', minimum)
        thresh = option.get('thresh', thresh)

    # Initialise the membership matrix U.
    U = initialise_U(data, cluster_number)
    exponent = 2 / (m - 1)
    C = None
    itor = 0
    while itor < maximum:
        # Keep a copy so the stopping condition can compare against it.
        U_old = copy.deepcopy(U)
        # Compute the cluster centers from the current memberships.
        C = _cluster_centers(data, U, cluster_number, m)
        # Distances from every sample to every center, used to update U.
        distance_matrix = dist.cdist(data, C, thresh)
        for i in range(len(data)):
            U[i] = _membership_row(distance_matrix[i], exponent)

        if end_conditon(U, U_old, minimum):
            print("结束聚类")
            break
        # Count the iteration so that `maximum` really bounds the loop.
        itor += 1
    if C is None:
        # No iteration ran (maximum <= 0): still return well-defined centers.
        C = _cluster_centers(data, U, cluster_number, m)
    print("标准化 U")
    # U = normalise_U(U)
    return U, C


def checker_iris(final_location):
    """
    Compare a hard clustering result with the true Iris grouping.

    The 150 rows are treated as three consecutive groups of 50. For each
    group the most frequent cluster (column equal to 1) is counted as
    correct, and the running total is printed after every group.

    :param final_location: one-hot style matrix with at least 150 rows.
    :return: the accuracy as a percentage, formatted as a string.
    """
    correct = 0.0
    for group_start in (0, 50, 100):
        votes = [0, 0, 0]
        for row in range(group_start, group_start + 50):
            for col in range(len(final_location[0])):
                if final_location[row][col] == 1:
                    votes[col] += 1
        correct += max(votes)
        print(correct)
    percentage = correct / 150 * 100
    return "准确度：" + str(percentage) + "%"


if __name__ == '__main__':
    # Demo: cluster 400 synthetic 2-D points drawn from three uniform boxes.
    # Load data
    # Training samples x: 200 points in [5, 8) x [5, 8), 100 points in
    # [12, 13) x [6, 8) and 100 points in [10, 12) x [8, 10).
    x1 = np.random.random((200, 2))*3 + 5
    x2 = np.hstack((np.random.random((100, 1))+2, np.random.random((100, 1))*2-4))+10
    x3 = np.hstack((np.random.random((100, 1))*2, np.random.random((100, 1))*2-2))+10
    x = np.vstack((x1,x2,x3))
    # Training labels y: 0 for the first 200 rows, then 1 and 2.
    y = np.zeros(400)
    # y[0:200] = 0
    y[200:300] = 1
    y[300:400] = 2
    data = x
    cluster_location = y
    # data, cluster_location = import_data_format_iris("iris.txt")
    # print_matrix(data)

    # Shuffle the data
    data, order = randomise_data(data)
    # print_matrix(data)

    start = time.time()
    # `data` now holds only the numeric samples
    # `cluster_location` holds the correct cluster index of every sample
    # Run fuzzy C-means with 3 clusters and m = 2
    U,C = fuzzy(data, 3, 2)

    # Restore the original order
    # final_location = de_randomise_data(final_location, order)
    # print_matrix(final_location)

    # Accuracy analysis
    # print(checker_iris(final_location))
    print("用时：{0}".format(time.time() - start))
    print(U)
